/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Per-(batch, head) sequence bookkeeping computed once at construction.
//
// Template parameter:
//   Varlen - when false, the cu_seqlens_* pointers are never read and the
//            fixed lengths params.seqlen_q / params.seqlen_k are used.
//            When true, each side (q, k) independently uses its cumulative
//            length array if that pointer is non-null, and falls back to the
//            fixed length otherwise.
//
// Fields read from Params: cu_seqlens_q, cu_seqlens_k (indexed at [bidb] and
// [bidb + 1]), seqlen_q, seqlen_k, h, scale_softmax.
template <bool Varlen = true>
struct BlockInfo {
    // params - kernel parameter struct providing the fields listed above.
    // bidb   - batch index; used to index the cu_seqlens_* arrays.
    // bidh   - head index; only used for h_slope.
    //
    // Initializers run in member declaration order, so sum_s_* are already
    // set when actual_seqlen_* subtract them, and both lengths are set before
    // row_shift is formed.
    template <typename Params>
    __device__ BlockInfo(const Params &params, const int bidb, const int bidh)
        : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
          sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
          actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q
                                                                    : params.cu_seqlens_q[bidb + 1] - sum_s_q),
          actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k
                                                                    : params.cu_seqlens_k[bidb + 1] - sum_s_k),
          row_shift(actual_seqlen_k - actual_seqlen_q),
          // h_slope = 1 / (2^(8 * (bidh + 1) / h) * scale_softmax).
          // Float literals keep the whole expression in single precision; the
          // previous double literals (1.0, 8.0) silently promoted the exponent
          // division and the reciprocal to double on the device.
          // NOTE(review): presumably a per-head positional-bias slope that is
          // pre-divided by the softmax scale - confirm against the consumer.
          h_slope(1.0f / (exp2f(8.0f * (bidh + 1) / params.h) * params.scale_softmax)) {}

    // Element offset of this batch's first query row.
    // If no cumulative offset was recorded (sum_s_q == -1) the offset is
    // bidb * batch_stride; otherwise it is sum_s_q * row_stride and bidb is
    // not used.
    template <typename index_t>
    inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
    }

    // Element offset of this batch's first key row; same rule as q_offset,
    // driven by sum_s_k.
    template <typename index_t>
    inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
    }

    const int sum_s_q;          // cu_seqlens_q[bidb], or -1 when the fixed-length path is used
    const int sum_s_k;          // cu_seqlens_k[bidb], or -1 when the fixed-length path is used
    const int actual_seqlen_q;  // query length for this batch element
    const int actual_seqlen_k;  // key length for this batch element
    const int row_shift;        // actual_seqlen_k - actual_seqlen_q
    const float h_slope;        // see constructor for the formula
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
